package config

import (
	"errors"
	"github.com/spf13/viper"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
)

var (
	// ErrNotFound reports that a requested configuration key is not set.
	// Configure.Value returns it wrapped in an error Value.
	ErrNotFound = errors.New("key not found")
	// ErrTypeAssert reports a failed type assertion.
	// NOTE(review): not referenced in this file; presumably used by the
	// Value implementations defined elsewhere — confirm.
	ErrTypeAssert = errors.New("type assert error")
)

// Config is the interface for loading configuration and reading values from it.
type Config interface {
	Load() error              // Load reads the configuration files.
	Value(key string) Value   // Value returns the value of a single key.
	Scan(v interface{}) error // Scan fills the given struct with the matching values.
}

// Option customizes a Configure; NewConfig applies every Option it is given
// to the instance it builds.
type Option func(*Configure)

// PlaceholderReplacer is a function type that maps one string to another.
// NOTE(review): unused in this file; presumably intended as a hook for
// pluggable placeholder expansion — confirm against callers.
type PlaceholderReplacer func(string) string

// Configure bundles a viper instance with the local directory and the file
// extensions used to discover configuration files.
type Configure struct {
	cached            sync.Map     // configuration cache; NOTE(review): not read or written by the live code in this file — confirm it is still needed
	localPath         string       // local configuration directory
	allowedExtensions []string     // accepted file extensions, stored without the leading dot
	V                 *viper.Viper // underlying viper instance
}

// NewConfig builds a Configure backed by a fresh viper instance and applies
// the given options in order. When no option registered a file extension,
// "yaml" is used as the only accepted extension.
func NewConfig(opts ...Option) *Configure {
	cfg := &Configure{
		V:                 viper.New(),
		allowedExtensions: []string{},
	}
	for i := range opts {
		opts[i](cfg)
	}
	if len(cfg.allowedExtensions) > 0 {
		return cfg
	}
	// No extension configured: fall back to the default.
	cfg.allowedExtensions = append(cfg.allowedExtensions, "yaml")
	return cfg
}

// WithConfigPath returns an Option that sets the local directory scanned for
// configuration files.
func WithConfigPath(path string) Option {
	return func(c *Configure) {
		c.localPath = path
	}
}

// WithAllowedExtensions returns an Option that appends the given file
// extensions to the accepted list. A single leading "." on an extension is
// removed, so "yaml" and ".yaml" are stored identically.
func WithAllowedExtensions(extensions ...string) Option {
	return func(c *Configure) {
		for i := range extensions {
			c.allowedExtensions = append(c.allowedExtensions, strings.TrimPrefix(extensions[i], "."))
		}
	}
}

// Value returns the configuration for the given key (earlier cache-backed variant, kept commented out).
//func (c *Configure) Value(key string) Value {
//	if v, ok := c.cached.Load(key); ok {
//		return v.(Value)
//	}
//	a := NewAtomicValue()
//	str := c.V.GetString(key)
//	if str == "" {
//		return &errValue{err: ErrNotFound}
//	}
//	a.Store(replacePlaceholder(c.V, str))
//	c.cached.Store(key, a)
//	return a
//}

// Value returns the value viper holds for key, wrapped in a Value. When viper
// reports the key as not set, the returned Value carries ErrNotFound instead.
func (c *Configure) Value(key string) Value {
	holder := NewAtomicValue()
	if !c.V.IsSet(key) {
		return &errValue{err: ErrNotFound}
	}
	// The key exists: store whatever viper returns for it.
	holder.Store(c.V.Get(key))
	return holder
}

// Scan unmarshals the configuration held by viper into v and returns any
// error viper reports.
func (c *Configure) Scan(v interface{}) error {
	if err := c.V.Unmarshal(v); err != nil {
		return err
	}
	return nil
}

// Load enables viper's automatic environment lookup, then reads every
// configuration file found by ScanLocalConfig: the first file is read as the
// base configuration and each later file is merged on top of it. Once all
// files are loaded, placeholders in the values are expanded via LoadALLValue.
// The first scan, read, or merge error is returned as-is.
func (c *Configure) Load() error {
	c.V.AutomaticEnv()

	// Local configuration files are loaded first.
	paths, err := c.ScanLocalConfig()
	if err != nil {
		return err
	}

	for idx, file := range paths {
		c.V.SetConfigFile(file)
		// Only the first file replaces the configuration; the rest are merged.
		ingest := c.V.MergeInConfig
		if idx == 0 {
			ingest = c.V.ReadInConfig
		}
		if err := ingest(); err != nil {
			return err
		}
	}

	c.LoadALLValue()
	return nil
}

// LoadALLValue expands "${name}" / "${name:default}" placeholders in every
// string value known to viper and stores the expanded text back under the
// same key.
//
// Only string values are touched. Reading every key with GetString and
// writing the result back would replace lists and maps with an empty string
// and turn numbers and booleans into strings, so non-string values are left
// exactly as they were loaded.
func (c *Configure) LoadALLValue() {
	for _, key := range c.V.AllKeys() {
		raw, isString := c.V.Get(key).(string)
		if !isString {
			// Non-string leaves cannot contain placeholders; keep their type.
			continue
		}
		// Write back only when expansion changed the text, so untouched keys
		// are not pinned as overrides.
		if expanded := replacePlaceholder(c.V, raw); expanded != raw {
			c.V.Set(key, expanded)
		}
	}
}

// ScanLocalConfig walks the configured local directory recursively and
// returns the paths of all regular entries whose extension is in the accepted
// list. Any error reported during the walk aborts the scan and is returned
// with a nil slice.
func (c *Configure) ScanLocalConfig() ([]string, error) {
	var matched []string

	visit := func(path string, info os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case info.IsDir():
			// Directories are traversed but never collected.
			return nil
		case hasAllowedExtension(path, c.allowedExtensions):
			matched = append(matched, path)
		}
		return nil
	}

	if err := filepath.Walk(c.localPath, visit); err != nil {
		return nil, err
	}
	return matched, nil
}

// hasAllowedExtension reports whether the extension of filePath is accepted
// by allowedExtensions. The comparison is case-insensitive, and "yml" and
// "yaml" are treated as equivalent: allowing either one accepts both.
// Entries in allowedExtensions are expected without a leading dot.
func hasAllowedExtension(filePath string, allowedExtensions []string) bool {
	fileExt := strings.ToLower(filepath.Ext(filePath))
	fileExt = strings.TrimPrefix(fileExt, ".")
	fileIsYAML := fileExt == "yml" || fileExt == "yaml"

	for _, candidate := range allowedExtensions {
		switch candidate = strings.ToLower(candidate); candidate {
		case "yml", "yaml":
			// Either YAML spelling in the allowed list accepts both spellings.
			if fileIsYAML {
				return true
			}
		default:
			if candidate == fileExt {
				return true
			}
		}
	}
	return false
}

// placeholderPattern matches "${name}" and "${name:default}" tokens. The name
// may not contain ':', '{' or '}'; the optional default may not contain '{'
// or '}'. It is compiled once at package scope because replacePlaceholder is
// invoked for every configuration key.
var placeholderPattern = regexp.MustCompile(`\${([^:{}]+)(?::([^{}]+))?}`)

// replacePlaceholder returns config with every "${name}" or "${name:default}"
// token substituted. A token is replaced by v.GetString(name) when viper
// reports name as set; otherwise it is replaced by the default, which is the
// empty string when the token carries none.
func replacePlaceholder(v *viper.Viper, config string) string {
	return placeholderPattern.ReplaceAllStringFunc(config, func(match string) string {
		// match is a complete token, so the submatch slice always has three
		// entries: the whole token, the name, and the (possibly empty) default.
		groups := placeholderPattern.FindStringSubmatch(match)
		name, fallback := groups[1], groups[2]
		if v.IsSet(name) {
			return v.GetString(name)
		}
		return fallback
	})
}
